/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.
*
* NVIDIA CORPORATION and its licensors retain all intellectual property
* and proprietary rights in and to this software, related documentation
* and any modifications thereto.  Any use, reproduction, disclosure or
* distribution of this software and related documentation without an express
* license agreement from NVIDIA CORPORATION is strictly prohibited.
*/

/** @file   triangle_bvh.cu
 *  @author Thomas Müller & Alex Evans, NVIDIA
 */

#include <neural-graphics-primitives/common_host.h>
#include <neural-graphics-primitives/triangle_bvh.cuh>

#include <tiny-cuda-nn/gpu_memory.h>

#include <stack>

#ifdef NGP_OPTIX
#  include <optix.h>
#  include <optix_stubs.h>
#  include <optix_function_table_definition.h>
#  include <optix_stack_size.h>

// Custom optix toolchain stuff
#  include "optix/pathescape.h"
#  include "optix/raystab.h"
#  include "optix/raytrace.h"

// Compiled optix program PTX generated by cmake and wrapped in a C
// header by bin2c.
namespace optix_ptx {
	#include <optix_ptx.h>
}
#endif //NGP_OPTIX

namespace ngp {

// Largest distance considered by the BVH queries in this file: it is the initial (and miss) ray
// parameter of ray_intersect and the default search radius of the distance queries/kernels.
constexpr float MAX_DIST{10.0f};

#ifdef NGP_OPTIX
// Process-wide OptiX device context; created by optix::initialize() and shared by every BVH.
OptixDeviceContext g_optix = nullptr;

namespace optix {
	// Creates the global OptiX device context `g_optix` on the currently active CUDA context.
	//
	// Only the first call does any work; every later call returns the cached outcome. Failures of
	// the OptiX calls are logged as a warning and reported as `false`. A failure of the initial CUDA
	// call propagates out of CUDA_CHECK_THROW. Even then this function counts as "already run", so
	// later calls return false.
	bool initialize() {
		static bool attempted = false;
		static bool succeeded = false;

		if (!attempted) {
			attempted = true;

			// Initialize CUDA with a no-op call to the CUDA runtime API
			CUDA_CHECK_THROW(cudaFree(nullptr));

			try {
				// Load all OptiX API entry points.
				OPTIX_CHECK_THROW(optixInit());

				// Default context options, bound to the current CUDA context (0 == use the active one).
				OptixDeviceContextOptions context_options = {};
				CUcontext active_cuda_context = 0;
				OPTIX_CHECK_THROW(optixDeviceContextCreate(active_cuda_context, &context_options, &g_optix));

				succeeded = true;
			} catch (std::exception& e) {
				tlog::warning() << "OptiX failed to initialize: " << e.what();
			}
		}

		return succeeded;
	}

	// One shader-binding-table record: an opaque header (filled in by optixSbtRecordPackHeader for a
	// specific program group) followed by the program's user data of type T. The header is aligned
	// to OPTIX_SBT_RECORD_ALIGNMENT, as OptiX requires for SBT records.
	template <typename T>
	struct SbtRecord {
		__align__( OPTIX_SBT_RECORD_ALIGNMENT ) char header[OPTIX_SBT_RECORD_HEADER_SIZE];
		T data;
	};

	// An OptiX pipeline built from a single PTX module, together with its shader binding table (SBT).
	//
	// `T` describes the program. It must provide the launch-parameter type `T::Params` and the SBT
	// payload types `T::RayGenData`, `T::MissData` and `T::HitGroupData`. The module must export the
	// entry points `__raygen__rg`, `__miss__ms` and `__closesthit__ch`, plus a launch-parameter variable
	// named `params_data`.
	//
	// The object owns every OptiX handle and SBT device allocation it creates. It releases them in
	// its destructor, and also when construction fails part-way, so it is non-copyable.
	template <typename T>
	class Program {
	public:
		// Compiles `data` (PTX source, `size` bytes) in the OptiX context `optix` and links it into a
		// pipeline with one raygen, one miss and one hit group. With `spheres == true`, the hit group
		// uses the built-in sphere intersector instead of triangles. Throws (via the *_CHECK_THROW
		// macros) if any OptiX or CUDA call fails.
		Program(const char* data, size_t size, OptixDeviceContext optix, bool spheres = false) {
			try {
				build(data, size, optix, spheres);
			} catch (...) {
				// The destructor does not run when a constructor throws, so free partial state here.
				release();
				throw;
			}
		}

		~Program() {
			release();
		}

		// Owning handles: a copy would destroy the same OptiX objects twice.
		Program(const Program&) = delete;
		Program& operator=(const Program&) = delete;

		// Copies `params` into the device-side launch-parameter buffer and launches `dim` ray-generation
		// invocations, both on `stream`.
		void invoke(const typename T::Params& params, const uint3& dim, cudaStream_t stream) {
			CUDA_CHECK_THROW(cudaMemcpyAsync(m_params_gpu.data(), &params, sizeof(typename T::Params), cudaMemcpyHostToDevice, stream));
			OPTIX_CHECK_THROW(optixLaunch(m_pipeline, stream, (CUdeviceptr)(uintptr_t)m_params_gpu.data(), sizeof(typename T::Params), &m_sbt, dim.x, dim.y, dim.z));
		}

	private:
		// Creates the module, program groups, pipeline and SBT. Every object is stored in a member as
		// soon as it exists, so release() can clean up if a later step throws.
		void build(const char* data, size_t size, OptixDeviceContext optix, bool spheres) {
			char log[2048]; // For error reporting from OptiX creation functions
			size_t sizeof_log = sizeof(log);

			// Module from PTX
			OptixModuleCompileOptions module_compile_options = {};
			OptixPipelineCompileOptions pipeline_compile_options = {};

			// Pipeline options must be consistent for all modules used in a
			// single pipeline
			pipeline_compile_options.usesMotionBlur = false;

			// This option is important to ensure we compile code which is optimal
			// for our scene hierarchy. We use a single GAS - no instancing or
			// multi-level hierarchies
			pipeline_compile_options.traversableGraphFlags =
				OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_GAS;
			pipeline_compile_options.usesPrimitiveTypeFlags = spheres ? OPTIX_PRIMITIVE_TYPE_FLAGS_SPHERE : OPTIX_PRIMITIVE_TYPE_FLAGS_TRIANGLE;

			// Our device code uses 3 payload registers (r,g,b output value)
			pipeline_compile_options.numPayloadValues = 3;

			// This is the name of the param struct variable in our device code
			pipeline_compile_options.pipelineLaunchParamsVariableName = "params_data";

			OPTIX_CHECK_THROW_LOG(optixModuleCreateFromPTX(
				optix,
				&module_compile_options,
				&pipeline_compile_options,
				data,
				size,
				log,
				&sizeof_log,
				&m_module
			));

			// Program groups
			{
				OptixProgramGroupOptions program_group_options   = {}; // Initialize to zeros

				OptixProgramGroupDesc raygen_prog_group_desc    = {};
				raygen_prog_group_desc.kind                     = OPTIX_PROGRAM_GROUP_KIND_RAYGEN;
				raygen_prog_group_desc.raygen.module            = m_module;
				raygen_prog_group_desc.raygen.entryFunctionName = "__raygen__rg";
				OPTIX_CHECK_THROW_LOG(optixProgramGroupCreate(
					optix,
					&raygen_prog_group_desc,
					1,   // num program groups
					&program_group_options,
					log,
					&sizeof_log,
					&m_raygen_prog_group
				));

				OptixProgramGroupDesc miss_prog_group_desc  = {};
				miss_prog_group_desc.kind                   = OPTIX_PROGRAM_GROUP_KIND_MISS;
				miss_prog_group_desc.miss.module            = m_module;
				miss_prog_group_desc.miss.entryFunctionName = "__miss__ms";
				OPTIX_CHECK_THROW_LOG(optixProgramGroupCreate(
					optix,
					&miss_prog_group_desc,
					1,   // num program groups
					&program_group_options,
					log,
					&sizeof_log,
					&m_miss_prog_group
				));

				OptixProgramGroupDesc hitgroup_prog_group_desc = {};
				hitgroup_prog_group_desc.kind                         = OPTIX_PROGRAM_GROUP_KIND_HITGROUP;
				hitgroup_prog_group_desc.hitgroup.moduleCH            = m_module;
				hitgroup_prog_group_desc.hitgroup.entryFunctionNameCH = "__closesthit__ch";

				if (spheres) {
					// NOTE(review): the built-in sphere IS module is neither stored nor destroyed here.
					// Confirm in the OptiX docs whether modules returned by optixBuiltinISModuleGet need
					// optixModuleDestroy.
					OptixBuiltinISOptions is_options = {};
					is_options.builtinISModuleType = OPTIX_PRIMITIVE_TYPE_SPHERE;
					is_options.buildFlags = OPTIX_BUILD_FLAG_NONE;
					is_options.usesMotionBlur = false;
					OPTIX_CHECK_THROW(optixBuiltinISModuleGet(
						optix,
						&module_compile_options,
						&pipeline_compile_options,
						&is_options,
						&hitgroup_prog_group_desc.hitgroup.moduleIS
					));
				}

				OPTIX_CHECK_THROW_LOG(optixProgramGroupCreate(
					optix,
					&hitgroup_prog_group_desc,
					1,   // num program groups
					&program_group_options,
					log,
					&sizeof_log,
					&m_hitgroup_prog_group
				));
			}

			// Linking
			{
				const uint32_t max_trace_depth = 1;
				OptixProgramGroup program_groups[] = { m_raygen_prog_group, m_miss_prog_group, m_hitgroup_prog_group };

				OptixPipelineLinkOptions pipeline_link_options = {};
				pipeline_link_options.maxTraceDepth = max_trace_depth;
				pipeline_link_options.debugLevel    = OPTIX_COMPILE_DEBUG_LEVEL_DEFAULT;

				OPTIX_CHECK_THROW_LOG(optixPipelineCreate(
					optix,
					&pipeline_compile_options,
					&pipeline_link_options,
					program_groups,
					sizeof(program_groups) / sizeof(program_groups[0]),
					log,
					&sizeof_log,
					&m_pipeline
				));

				OptixStackSizes stack_sizes = {};
				for (auto& prog_group : program_groups) {
					OPTIX_CHECK_THROW(optixUtilAccumulateStackSizes(prog_group, &stack_sizes));
				}

				uint32_t direct_callable_stack_size_from_traversal;
				uint32_t direct_callable_stack_size_from_state;
				uint32_t continuation_stack_size;
				OPTIX_CHECK_THROW(optixUtilComputeStackSizes(
					&stack_sizes, max_trace_depth,
					0,  // maxCCDepth
					0,  // maxDCDepth
					&direct_callable_stack_size_from_traversal,
					&direct_callable_stack_size_from_state, &continuation_stack_size
				));
				OPTIX_CHECK_THROW(optixPipelineSetStackSize(
					m_pipeline, direct_callable_stack_size_from_traversal,
					direct_callable_stack_size_from_state, continuation_stack_size,
					1  // maxTraversableDepth
				));
			}

			// Shader binding table: one record per program group. Each device pointer is written into
			// m_sbt directly by cudaMalloc, so release() frees it even if a later step throws. The host
			// records are zero-initialized so no uninitialized bytes are uploaded.
			{
				SbtRecord<typename T::RayGenData> rg_sbt = {};
				OPTIX_CHECK_THROW(optixSbtRecordPackHeader(m_raygen_prog_group, &rg_sbt));
				CUDA_CHECK_THROW(cudaMalloc(reinterpret_cast<void**>(&m_sbt.raygenRecord), sizeof(rg_sbt)));
				CUDA_CHECK_THROW(cudaMemcpy(
					reinterpret_cast<void*>(m_sbt.raygenRecord),
					&rg_sbt,
					sizeof(rg_sbt),
					cudaMemcpyHostToDevice
				));

				SbtRecord<typename T::MissData> ms_sbt = {};
				OPTIX_CHECK_THROW(optixSbtRecordPackHeader(m_miss_prog_group, &ms_sbt));
				CUDA_CHECK_THROW(cudaMalloc(reinterpret_cast<void**>(&m_sbt.missRecordBase), sizeof(ms_sbt)));
				CUDA_CHECK_THROW(cudaMemcpy(
					reinterpret_cast<void*>(m_sbt.missRecordBase),
					&ms_sbt,
					sizeof(ms_sbt),
					cudaMemcpyHostToDevice
				));

				SbtRecord<typename T::HitGroupData> hg_sbt = {};
				OPTIX_CHECK_THROW(optixSbtRecordPackHeader(m_hitgroup_prog_group, &hg_sbt));
				CUDA_CHECK_THROW(cudaMalloc(reinterpret_cast<void**>(&m_sbt.hitgroupRecordBase), sizeof(hg_sbt)));
				CUDA_CHECK_THROW(cudaMemcpy(
					reinterpret_cast<void*>(m_sbt.hitgroupRecordBase),
					&hg_sbt,
					sizeof(hg_sbt),
					cudaMemcpyHostToDevice
				));

				m_sbt.missRecordStrideInBytes     = sizeof(SbtRecord<typename T::MissData>);
				m_sbt.missRecordCount             = 1;
				m_sbt.hitgroupRecordStrideInBytes = sizeof(SbtRecord<typename T::HitGroupData>);
				m_sbt.hitgroupRecordCount         = 1;
			}
		}

		// Frees the SBT records and destroys the pipeline, program groups and module (in that order).
		// Null handles are skipped, so this is safe on a partially built program and idempotent.
		// Errors are ignored because this runs from the destructor, which must not throw.
		void release() {
			if (m_sbt.raygenRecord) {
				(void)cudaFree(reinterpret_cast<void*>(m_sbt.raygenRecord));
				m_sbt.raygenRecord = 0;
			}
			if (m_sbt.missRecordBase) {
				(void)cudaFree(reinterpret_cast<void*>(m_sbt.missRecordBase));
				m_sbt.missRecordBase = 0;
			}
			if (m_sbt.hitgroupRecordBase) {
				(void)cudaFree(reinterpret_cast<void*>(m_sbt.hitgroupRecordBase));
				m_sbt.hitgroupRecordBase = 0;
			}

			if (m_pipeline) {
				(void)optixPipelineDestroy(m_pipeline);
				m_pipeline = nullptr;
			}
			if (m_hitgroup_prog_group) {
				(void)optixProgramGroupDestroy(m_hitgroup_prog_group);
				m_hitgroup_prog_group = nullptr;
			}
			if (m_miss_prog_group) {
				(void)optixProgramGroupDestroy(m_miss_prog_group);
				m_miss_prog_group = nullptr;
			}
			if (m_raygen_prog_group) {
				(void)optixProgramGroupDestroy(m_raygen_prog_group);
				m_raygen_prog_group = nullptr;
			}
			if (m_module) {
				(void)optixModuleDestroy(m_module);
				m_module = nullptr;
			}
		}

		OptixShaderBindingTable m_sbt = {};
		OptixPipeline m_pipeline = nullptr;
		OptixModule m_module = nullptr;
		OptixProgramGroup m_raygen_prog_group = nullptr;
		OptixProgramGroup m_miss_prog_group = nullptr;
		OptixProgramGroup m_hitgroup_prog_group = nullptr;
		GPUMemory<typename T::Params> m_params_gpu = GPUMemory<typename T::Params>(1);
	};

	// OptiX geometry acceleration structure (GAS) over a triangle soup, built once in the constructor.
	class Gas {
	public:
		// Builds the GAS for `triangles` in context `optix`, with the build enqueued on `stream`. Each
		// triangle is passed to OptiX as three consecutive float3 vertices; no index buffer is used.
		// NOTE(review): this assumes Triangle is exactly three tightly packed float3 vertices, because
		// vertexStrideInBytes is left at 0 (tightly packed). Confirm against the Triangle definition.
		// Throws (via the *_CHECK_THROW macros) if any OptiX or CUDA call fails.
		Gas(const GPUMemory<Triangle>& triangles, OptixDeviceContext optix, cudaStream_t stream) {
			// Specify options for the build. We use default options for simplicity.
			OptixAccelBuildOptions accel_options = {};
			accel_options.buildFlags = OPTIX_BUILD_FLAG_NONE;
			accel_options.operation = OPTIX_BUILD_OPERATION_BUILD;

			// Populate the build input struct with our triangle data as well as
			// information about the sizes and types of our data
			const uint32_t triangle_input_flags[1] = { OPTIX_GEOMETRY_FLAG_NONE };
			OptixBuildInput triangle_input = {};

			CUdeviceptr d_triangles = (CUdeviceptr)(uintptr_t)triangles.data();

			triangle_input.type = OPTIX_BUILD_INPUT_TYPE_TRIANGLES;
			triangle_input.triangleArray.vertexFormat = OPTIX_VERTEX_FORMAT_FLOAT3;
			triangle_input.triangleArray.numVertices = (uint32_t)triangles.size()*3;
			triangle_input.triangleArray.vertexBuffers = &d_triangles;
			triangle_input.triangleArray.flags = triangle_input_flags;
			triangle_input.triangleArray.numSbtRecords = 1;

			// Query OptiX for the memory requirements for our GAS
			OptixAccelBufferSizes gas_buffer_sizes;
			OPTIX_CHECK_THROW(optixAccelComputeMemoryUsage(optix, &accel_options, &triangle_input, 1, &gas_buffer_sizes));

			// Allocate device memory for the scratch space buffer as well
			// as the GAS itself
			GPUMemory<char> gas_tmp_buffer{gas_buffer_sizes.tempSizeInBytes};
			m_gas_gpu_buffer.resize(gas_buffer_sizes.outputSizeInBytes);

			OPTIX_CHECK_THROW(optixAccelBuild(
				optix,
				stream,
				&accel_options,
				&triangle_input,
				1,           // num build inputs
				(CUdeviceptr)(uintptr_t)gas_tmp_buffer.data(),
				gas_buffer_sizes.tempSizeInBytes,
				(CUdeviceptr)(uintptr_t)m_gas_gpu_buffer.data(),
				gas_buffer_sizes.outputSizeInBytes,
				&m_gas_handle, // Output handle to the struct
				nullptr,       // emitted property list
				0              // num emitted properties
			));

			// The build is enqueued on `stream`, but the scratch buffer is a local that goes out of
			// scope when this constructor returns. Wait for the build explicitly rather than relying on
			// the deallocation happening to synchronize.
			CUDA_CHECK_THROW(cudaStreamSynchronize(stream));
		}

		OptixTraversableHandle handle() const {
			return m_gas_handle;
		}

	private:
		OptixTraversableHandle m_gas_handle = 0;
		GPUMemory<char> m_gas_gpu_buffer;
	};
}
#endif //NGP_OPTIX

// Kernels launched by TriangleBvhWithBranchingFactor. They are defined at the end of this file,
// after the class. The default for the trailing flag lives here because C++ allows a default
// argument to be given only once.
__global__ void signed_distance_watertight_kernel(
	uint32_t n_elements,
	const vec3* __restrict__ positions,
	const TriangleBvhNode* __restrict__ bvhnodes,
	const Triangle* __restrict__ triangles,
	float* __restrict__ distances,
	bool use_existing_distances_as_upper_bounds = false
);

__global__ void signed_distance_raystab_kernel(
	uint32_t n_elements,
	const vec3* __restrict__ positions,
	const TriangleBvhNode* __restrict__ bvhnodes,
	const Triangle* __restrict__ triangles,
	float* __restrict__ distances,
	bool use_existing_distances_as_upper_bounds = false
);

__global__ void unsigned_distance_kernel(
	uint32_t n_elements,
	const vec3* __restrict__ positions,
	const TriangleBvhNode* __restrict__ bvhnodes,
	const Triangle* __restrict__ triangles,
	float* __restrict__ distances,
	bool use_existing_distances_as_upper_bounds = false
);

__global__ void raytrace_kernel(
	uint32_t n_elements,
	vec3* __restrict__ positions,
	vec3* __restrict__ directions,
	const TriangleBvhNode* __restrict__ nodes,
	const Triangle* __restrict__ triangles
);

// A (distance key, node index) pair used to order BVH children during traversal.
//
// operator< is an ordinary ascending comparison on `dist`. The descending order used by the
// traversal comes from compare_and_swap, which moves the larger element to the front. The operator
// is const so that it can also compare const objects.
struct DistAndIdx {
	float dist;
	uint32_t idx;

	__host__ __device__ bool operator<(const DistAndIdx& other) const {
		return dist < other.dist;
	}
};

// Conditionally exchanges two elements so that the larger one (with respect to T's operator<)
// ends up in t1. Only operator< is required of T.
template <typename T>
__host__ __device__ void inline compare_and_swap(T& t1, T& t2) {
	if (!(t1 < t2)) {
		return;
	}
	T held = t1;
	t1 = t2;
	t2 = held;
}

// Sorts the N elements of `values` in DESCENDING order with a fixed comparator network (no data-
// dependent control flow apart from the individual exchanges). The networks are taken from
// https://bertdobbelaere.github.io/sorting_networks.html (e.g. N=4 uses 5 comparators).
template <uint32_t N, typename T>
__host__ __device__ void sorting_network(T values[N]) {
	static_assert(N <= 8, "Sorting networks are only implemented up to N==8");
	switch (N) {
		case 0:
		case 1:
			// Nothing to order.
			break;

		case 2:
			compare_and_swap(values[0], values[1]);
			break;

		case 3:
			compare_and_swap(values[0], values[2]);
			compare_and_swap(values[0], values[1]);
			compare_and_swap(values[1], values[2]);
			break;

		case 4:
			compare_and_swap(values[0], values[2]); compare_and_swap(values[1], values[3]);
			compare_and_swap(values[0], values[1]); compare_and_swap(values[2], values[3]);
			compare_and_swap(values[1], values[2]);
			break;

		case 5:
			compare_and_swap(values[0], values[3]); compare_and_swap(values[1], values[4]);
			compare_and_swap(values[0], values[2]); compare_and_swap(values[1], values[3]);
			compare_and_swap(values[0], values[1]); compare_and_swap(values[2], values[4]);
			compare_and_swap(values[1], values[2]); compare_and_swap(values[3], values[4]);
			compare_and_swap(values[2], values[3]);
			break;

		case 6:
			compare_and_swap(values[0], values[5]); compare_and_swap(values[1], values[3]); compare_and_swap(values[2], values[4]);
			compare_and_swap(values[1], values[2]); compare_and_swap(values[3], values[4]);
			compare_and_swap(values[0], values[3]); compare_and_swap(values[2], values[5]);
			compare_and_swap(values[0], values[1]); compare_and_swap(values[2], values[3]); compare_and_swap(values[4], values[5]);
			compare_and_swap(values[1], values[2]); compare_and_swap(values[3], values[4]);
			break;

		case 7:
			compare_and_swap(values[0], values[6]); compare_and_swap(values[2], values[3]); compare_and_swap(values[4], values[5]);
			compare_and_swap(values[0], values[2]); compare_and_swap(values[1], values[4]); compare_and_swap(values[3], values[6]);
			compare_and_swap(values[0], values[1]); compare_and_swap(values[2], values[5]); compare_and_swap(values[3], values[4]);
			compare_and_swap(values[1], values[2]); compare_and_swap(values[4], values[6]);
			compare_and_swap(values[2], values[3]); compare_and_swap(values[4], values[5]);
			compare_and_swap(values[1], values[2]); compare_and_swap(values[3], values[4]); compare_and_swap(values[5], values[6]);
			break;

		case 8:
			compare_and_swap(values[0], values[2]); compare_and_swap(values[1], values[3]); compare_and_swap(values[4], values[6]); compare_and_swap(values[5], values[7]);
			compare_and_swap(values[0], values[4]); compare_and_swap(values[1], values[5]); compare_and_swap(values[2], values[6]); compare_and_swap(values[3], values[7]);
			compare_and_swap(values[0], values[1]); compare_and_swap(values[2], values[3]); compare_and_swap(values[4], values[5]); compare_and_swap(values[6], values[7]);
			compare_and_swap(values[2], values[4]); compare_and_swap(values[3], values[5]);
			compare_and_swap(values[1], values[4]); compare_and_swap(values[3], values[6]);
			compare_and_swap(values[1], values[2]); compare_and_swap(values[3], values[4]); compare_and_swap(values[5], values[6]);
			break;
	}
}

// BVH over a triangle soup in which every inner node has exactly BRANCHING_FACTOR children.
//
// Node encoding produced by build():
//  - inner node: `left_idx` is the index of its first child (children are stored contiguously) and
//    `right_idx` is one past its last child;
//  - leaf: covers the triangle range [b, e), encoded as `left_idx = -b-1` and `right_idx = -e-1`.
//    A negative `left_idx` therefore identifies a leaf.
//
// The static __host__ __device__ queries take raw node/triangle pointers, so the same traversal code
// serves both the CUDA kernels at the end of this file and host-side calls.
template <uint32_t BRANCHING_FACTOR>
class TriangleBvhWithBranchingFactor : public TriangleBvh {
	// build() reaches BRANCHING_FACTOR children by repeatedly halving every child, which only works
	// for powers of two. A factor of 1 would never shrink a node.
	static_assert(BRANCHING_FACTOR >= 2 && (BRANCHING_FACTOR & (BRANCHING_FACTOR - 1)) == 0, "TriangleBvh: BRANCHING_FACTOR must be a power of two >= 2");

public:
	// Finds the closest triangle hit by the ray `ro + t * rd` with `t < MAX_DIST`.
	// Returns {triangle index, t}. The index is -1 and t is MAX_DIST if there is no such hit.
	__host__ __device__ static std::pair<int, float> ray_intersect(const vec3& ro, const vec3& rd, const TriangleBvhNode* __restrict__ bvhnodes, const Triangle* __restrict__ triangles) {
		FixedIntStack query_stack;
		query_stack.push(0);

		float mint = MAX_DIST;
		int shortest_idx = -1;

		while (!query_stack.empty()) {
			int idx = query_stack.pop();

			const TriangleBvhNode& node = bvhnodes[idx];

			if (node.left_idx < 0) {
				int end = -node.right_idx-1;
				for (int i = -node.left_idx-1; i < end; ++i) {
					float t = triangles[i].ray_intersect(ro, rd);
					if (t < mint) {
						mint = t;
						shortest_idx = i;
					}
				}
			} else {
				// Key each child by the .x component of its box's ray_intersect result (assumed to be
				// the entry distance).
				DistAndIdx children[BRANCHING_FACTOR];

				uint32_t first_child = node.left_idx;

				NGP_PRAGMA_UNROLL
				for (uint32_t i = 0; i < BRANCHING_FACTOR; ++i) {
					children[i] = {bvhnodes[i+first_child].bb.ray_intersect(ro, rd).x, i+first_child};
				}

				// sorting_network orders farthest-first, so the nearest child is pushed last and popped next.
				sorting_network<BRANCHING_FACTOR>(children);

				NGP_PRAGMA_UNROLL
				for (uint32_t i = 0; i < BRANCHING_FACTOR; ++i) {
					if (children[i].dist < mint) {
						query_stack.push(children[i].idx);
					}
				}
			}
		}

		return {shortest_idx, mint};
	}

	// Finds the triangle closest to `point` among those with squared distance <= max_distance_sq.
	// Returns {triangle index, distance}. If no triangle lies within the bound, it returns {0, 0.0f}.
	__host__ __device__ static std::pair<int, float> closest_triangle(const vec3& point, const TriangleBvhNode* __restrict__ bvhnodes, const Triangle* __restrict__ triangles, float max_distance_sq) {
		FixedIntStack query_stack;
		query_stack.push(0);

		float shortest_distance_sq = max_distance_sq;
		int shortest_idx = -1;

		while (!query_stack.empty()) {
			int idx = query_stack.pop();

			const TriangleBvhNode& node = bvhnodes[idx];

			if (node.left_idx < 0) {
				int end = -node.right_idx-1;
				for (int i = -node.left_idx-1; i < end; ++i) {
					float dist_sq = triangles[i].distance_sq(point);
					if (dist_sq <= shortest_distance_sq) {
						shortest_distance_sq = dist_sq;
						shortest_idx = i;
					}
				}
			} else {
				DistAndIdx children[BRANCHING_FACTOR];

				uint32_t first_child = node.left_idx;

				NGP_PRAGMA_UNROLL
				for (uint32_t i = 0; i < BRANCHING_FACTOR; ++i) {
					children[i] = {bvhnodes[i+first_child].bb.distance_sq(point), i+first_child};
				}

				// Farthest-first, so the nearest box is explored first and tightens the bound early.
				sorting_network<BRANCHING_FACTOR>(children);

				NGP_PRAGMA_UNROLL
				for (uint32_t i = 0; i < BRANCHING_FACTOR; ++i) {
					if (children[i].dist <= shortest_distance_sq) {
						query_stack.push(children[i].idx);
					}
				}
			}
		}

		if (shortest_idx == -1) {
			// printf("No closest triangle found. This must be a bug! %d\n", BRANCHING_FACTOR);
			shortest_idx = 0;
			shortest_distance_sq = 0.0f;
		}

		return {shortest_idx, std::sqrt(shortest_distance_sq)};
	}

	// Averages the normals of all triangles touching `point`, i.e. those with squared distance < 1e-12.
	// With fewer than three such triangles, the plain sum of normals is normalized. Otherwise each
	// normal is weighted by the triangle's angle at `point` (Triangle::angle_at_pos).
	// Assumes that "point" is a location on a triangle
	__host__ __device__ static vec3 avg_normal_around_point(const vec3& point, const TriangleBvhNode* __restrict__ bvhnodes, const Triangle* __restrict__ triangles) {
		FixedIntStack query_stack;
		query_stack.push(0);

		static constexpr float EPSILON = 1e-12f;

		uint32_t n_tris = 0;
		vec3 normal_weighted_sum = vec3(0.0f);
		vec3 normal_sum = vec3(0.0f);

		while (!query_stack.empty()) {
			int idx = query_stack.pop();

			const TriangleBvhNode& node = bvhnodes[idx];

			if (node.left_idx < 0) {
				int end = -node.right_idx-1;
				for (int i = -node.left_idx-1; i < end; ++i) {
					const Triangle& tri = triangles[i];
					if (tri.distance_sq(point) < EPSILON) {
						vec3 n = tri.normal();
						normal_sum += n;
						normal_weighted_sum += tri.angle_at_pos(point) * n;
						++n_tris;
					}
				}
			} else {
				uint32_t first_child = node.left_idx;

				NGP_PRAGMA_UNROLL
				for (uint32_t i = 0; i < BRANCHING_FACTOR; ++i) {
					if (bvhnodes[i+first_child].bb.distance_sq(point) < EPSILON) {
						query_stack.push(i+first_child);
					}
				}
			}
		}

		if (n_tris < 3) {
			return normalize(normal_sum);
		} else {
			return normalize(normal_weighted_sum);
		}
	}

	// Distance from `point` to the closest triangle within sqrt(max_distance_sq). It is 0 if there is none.
	__host__ __device__ static float unsigned_distance(const vec3& point, const TriangleBvhNode* __restrict__ bvhnodes, const Triangle* __restrict__ triangles, float max_distance_sq) {
		return closest_triangle(point, bvhnodes, triangles, max_distance_sq).second;
	}

	// Signed distance for a closed mesh. The magnitude comes from closest_triangle. The sign is that of
	// dot(averaged normal at the closest point, point - closest point). The result is (signed) zero
	// if the averaged normal is the zero vector.
	__host__ __device__ static float signed_distance_watertight(const vec3& point, const TriangleBvhNode* __restrict__ bvhnodes, const Triangle* __restrict__ triangles, float max_distance_sq) {
		auto p = closest_triangle(point, bvhnodes, triangles, max_distance_sq);

		const Triangle& tri = triangles[p.first];
		vec3 closest_point = tri.closest_point(point);
		vec3 avg_normal = avg_normal_around_point(closest_point, bvhnodes, triangles);

		return copysign(avg_normal == vec3(0.0f) ? 0.0f : p.second, dot(avg_normal, point - closest_point));
	}

	// Signed distance via ray stabbing. It is positive if any of 32 Fibonacci-lattice rays (with a random
	// lattice offset drawn from `rng`) hits no triangle within MAX_DIST, and negative otherwise.
	__host__ __device__ static float signed_distance_raystab(const vec3& point, const TriangleBvhNode* __restrict__ bvhnodes, const Triangle* __restrict__ triangles, float max_distance_sq, default_rng_t rng={}) {
		float distance = unsigned_distance(point, bvhnodes, triangles, max_distance_sq);

		vec2 offset = random_val_2d(rng);

		static constexpr uint32_t N_STAB_RAYS = 32;
		for (uint32_t i = 0; i < N_STAB_RAYS; ++i) {
			// Use a Fibonacci lattice (with random offset) to regularly
			// distribute the stab rays over the sphere.
			vec3 d = fibonacci_dir<N_STAB_RAYS>(i, offset);

			// If any of the stab rays goes outside the mesh, the SDF is positive.
			if (ray_intersect(point, d, bvhnodes, triangles).first < 0) {
				return distance;
			}
		}

		return -distance;
	}

	// Host-side wrapper of the static query above, using this BVH's host node array.
	// Assumes that "point" is a location on a triangle
	vec3 avg_normal_around_point(const vec3& point, const Triangle* __restrict__ triangles) const {
		return avg_normal_around_point(point, m_nodes.data(), triangles);
	}

	// Host-side signed distance with search radius MAX_DIST. Watertight mode uses the normal-based
	// sign; every other mode uses ray stabbing.
	float signed_distance(EMeshSdfMode mode, const vec3& point, const std::vector<Triangle>& triangles) const {
		if (mode == EMeshSdfMode::Watertight) {
			return signed_distance_watertight(point, m_nodes.data(), triangles.data(), MAX_DIST*MAX_DIST);
		} else {
			return signed_distance_raystab(point, m_nodes.data(), triangles.data(), MAX_DIST*MAX_DIST);
		}
	}

	// Computes signed distances of `n_elements` device positions into `gpu_distances` on `stream`.
	// Watertight mode always uses the CUDA BVH kernel. With OptiX available, Raystab/PathEscape first
	// compute unsigned distances with the BVH kernel, then invoke the corresponding OptiX program on
	// the same buffers. Without OptiX, Raystab falls back to the CUDA kernel and PathEscape throws.
	void signed_distance_gpu(uint32_t n_elements, EMeshSdfMode mode, const vec3* gpu_positions, float* gpu_distances, const Triangle* gpu_triangles, bool use_existing_distances_as_upper_bounds, cudaStream_t stream) override {
		if (mode == EMeshSdfMode::Watertight) {
			linear_kernel(signed_distance_watertight_kernel, 0, stream,
				n_elements,
				gpu_positions,
				m_nodes_gpu.data(),
				gpu_triangles,
				gpu_distances,
				use_existing_distances_as_upper_bounds
			);
		} else {
#ifdef NGP_OPTIX
			if (m_optix.available) {
				linear_kernel(unsigned_distance_kernel, 0, stream,
					n_elements,
					gpu_positions,
					m_nodes_gpu.data(),
					gpu_triangles,
					gpu_distances,
					use_existing_distances_as_upper_bounds
				);

				if (mode == EMeshSdfMode::Raystab) {
					m_optix.raystab->invoke({gpu_positions, gpu_distances, m_optix.gas->handle()}, {n_elements, 1, 1}, stream);
				} else if (mode == EMeshSdfMode::PathEscape) {
					m_optix.pathescape->invoke({gpu_positions, gpu_triangles, gpu_distances, m_optix.gas->handle()}, {n_elements, 1, 1}, stream);
				}
			} else
#endif //NGP_OPTIX
			{
				if (mode == EMeshSdfMode::Raystab) {
					linear_kernel(signed_distance_raystab_kernel, 0, stream,
						n_elements,
						gpu_positions,
						m_nodes_gpu.data(),
						gpu_triangles,
						gpu_distances,
						use_existing_distances_as_upper_bounds
					);
				} else if (mode == EMeshSdfMode::PathEscape) {
					throw std::runtime_error{"TriangleBvh: EMeshSdfMode::PathEscape is only supported with OptiX enabled."};
				}
			}
		}
	}

	// Traces one ray per element on `stream`. It uses the OptiX raytrace program when available and
	// otherwise raytrace_kernel, which moves each position to its hit point (or MAX_DIST along the ray
	// on a miss) and replaces the direction of hitting rays with the hit triangle's normal.
	void ray_trace_gpu(uint32_t n_elements, vec3* gpu_positions, vec3* gpu_directions, const Triangle* gpu_triangles, cudaStream_t stream) override {
#ifdef NGP_OPTIX
		if (m_optix.available) {
			m_optix.raytrace->invoke({gpu_positions, gpu_directions, gpu_triangles, m_optix.gas->handle()}, {n_elements, 1, 1}, stream);
		} else
#endif //NGP_OPTIX
		{
			linear_kernel(raytrace_kernel, 0, stream,
				n_elements,
				gpu_positions,
				gpu_directions,
				m_nodes_gpu.data(),
				gpu_triangles
			);
		}
	}

	// Returns true if `bb` intersects any triangle stored in the subtree rooted at `node` (host nodes).
	bool touches_triangle(const BoundingBox& bb, const TriangleBvhNode& node, const Triangle* __restrict__ triangles) const {
		if (!node.bb.intersects(bb)) {
			return false;
		}

		if (node.left_idx < 0) {
			// Touches triangle leaves?
			int end = -node.right_idx-1;
			for (int i = -node.left_idx-1; i < end; ++i) {
				if (bb.intersects(triangles[i])) {
					return true;
				}
			}
		} else {
			// Touches children?
			int child_idx = node.left_idx;
			for (uint32_t i = 0; i < BRANCHING_FACTOR; ++i) {
				if (touches_triangle(bb, m_nodes[i+child_idx], triangles)) {
					return true;
				}
			}
		}

		return false;
	}

	bool touches_triangle(const BoundingBox& bb, const Triangle* __restrict__ triangles) const override {
		return touches_triangle(bb, m_nodes.front(), triangles);
	}

	// Builds the hierarchy top-down, reordering `triangles` in place so that every leaf covers a
	// contiguous range. Each node is split into BRANCHING_FACTOR children by repeated median splits
	// along the axis of largest centroid variance. A child becomes a leaf once it holds at most
	// n_primitives_per_leaf triangles, or when it is too small to be split into BRANCHING_FACTOR
	// non-empty children. The finished node array is uploaded to the GPU.
	void build(std::vector<Triangle>& triangles, uint32_t n_primitives_per_leaf) override {
		m_nodes.clear();

		// Root
		m_nodes.emplace_back();
		m_nodes.front().bb = BoundingBox(triangles.data(), triangles.data() + triangles.size());

		struct BuildNode {
			int node_idx;
			std::vector<Triangle>::iterator begin;
			std::vector<Triangle>::iterator end;
		};

		std::stack<BuildNode> build_stack;
		if (triangles.size() < BRANCHING_FACTOR) {
			// Too few triangles to give every child of the root at least one. Median halving would
			// create empty children (with 0/0 statistics and invalid bounding boxes), so the root
			// becomes a leaf covering all triangles instead.
			m_nodes.front().left_idx = -1;
			m_nodes.front().right_idx = -(int)triangles.size()-1;
		} else {
			build_stack.push({0, std::begin(triangles), std::end(triangles)});
		}

		while (!build_stack.empty()) {
			const BuildNode& curr = build_stack.top();
			size_t node_idx = curr.node_idx;

			std::array<BuildNode, BRANCHING_FACTOR> children;
			children[0].begin = curr.begin;
			children[0].end = curr.end;

			build_stack.pop();

			// Partition the triangles into the children
			int n_children = 1;
			while (n_children < BRANCHING_FACTOR) {
				for (int i = n_children - 1; i >= 0; --i) {
					auto& child = children[i];

					// Choose axis with maximum standard deviation
					vec3 mean = vec3(0.0f);
					for (auto it = child.begin; it != child.end; ++it) {
						mean += it->centroid();
					}
					mean /= (float)std::distance(child.begin, child.end);

					vec3 var = vec3(0.0f);
					for (auto it = child.begin; it != child.end; ++it) {
						vec3 diff = it->centroid() - mean;
						var += diff * diff;
					}
					var /= (float)std::distance(child.begin, child.end);

					float max_val = max(var);
					int axis = var.x == max_val ? 0 : (var.y == max_val ? 1 : 2);

					auto m = child.begin + std::distance(child.begin, child.end)/2;
					std::nth_element(child.begin, m, child.end, [&](const Triangle& tri1, const Triangle& tri2) { return tri1.centroid(axis) < tri2.centroid(axis); });

					children[i*2].begin = children[i].begin;
					children[i*2+1].end = children[i].end;
					children[i*2].end = children[i*2+1].begin = m;
				}

				n_children *= 2;
			}

			// Create next build nodes
			m_nodes[node_idx].left_idx = (int)m_nodes.size();
			for (uint32_t i = 0; i < BRANCHING_FACTOR; ++i) {
				auto& child = children[i];
				assert(child.begin != child.end);
				child.node_idx = (int)m_nodes.size();

				const size_t n_child_triangles = (size_t)std::distance(child.begin, child.end);

				// Form the range from the (non-empty) begin iterator; dereferencing child.end would be UB
				// for the last child, where it equals triangles.end().
				const Triangle* child_first = &*child.begin;

				m_nodes.emplace_back();
				m_nodes.back().bb = BoundingBox(child_first, child_first + n_child_triangles);

				// Ranges with fewer than BRANCHING_FACTOR triangles cannot be split into non-empty
				// children, so they become leaves even if they exceed n_primitives_per_leaf.
				if (n_child_triangles <= n_primitives_per_leaf || n_child_triangles < BRANCHING_FACTOR) {
					m_nodes.back().left_idx = -(int)std::distance(std::begin(triangles), child.begin)-1;
					m_nodes.back().right_idx = -(int)std::distance(std::begin(triangles), child.end)-1;
				} else {
					build_stack.push(child);
				}
			}
			m_nodes[node_idx].right_idx = (int)m_nodes.size();
		}

		m_nodes_gpu.resize_and_copy_from_host(m_nodes);

		tlog::success() << "Built TriangleBvh: nodes=" << m_nodes.size();
	}

	// Initializes OptiX (only the first optix::initialize() call does work) and builds the GAS and the
	// raystab/raytrace/pathescape programs. If OptiX is unavailable, it logs a warning and the CUDA
	// BVH kernels remain in use.
	void build_optix(const GPUMemory<Triangle>& triangles, cudaStream_t stream) override {
#ifdef NGP_OPTIX
		m_optix.available = optix::initialize();
		if (m_optix.available) {
			m_optix.gas = std::make_unique<optix::Gas>(triangles, g_optix, stream);
			m_optix.raystab = std::make_unique<optix::Program<Raystab>>((const char*)optix_ptx::raystab_ptx, sizeof(optix_ptx::raystab_ptx), g_optix);
			m_optix.raytrace = std::make_unique<optix::Program<Raytrace>>((const char*)optix_ptx::raytrace_ptx, sizeof(optix_ptx::raytrace_ptx), g_optix);
			m_optix.pathescape = std::make_unique<optix::Program<PathEscape>>((const char*)optix_ptx::pathescape_ptx, sizeof(optix_ptx::pathescape_ptx), g_optix);
			tlog::success() << "Built OptiX GAS and shaders";
		} else {
			tlog::warning() << "Falling back to slower TriangleBVH::ray_intersect.";
		}
#else //NGP_OPTIX
		tlog::warning() << "OptiX was not built. Falling back to slower TriangleBVH::ray_intersect.";
#endif //NGP_OPTIX
	}

	TriangleBvhWithBranchingFactor() {}

private:
#ifdef NGP_OPTIX
	struct {
		std::unique_ptr<optix::Gas> gas;
		std::unique_ptr<optix::Program<Raystab>> raystab;
		std::unique_ptr<optix::Program<Raytrace>> raytrace;
		std::unique_ptr<optix::Program<PathEscape>> pathescape;
		bool available = false;
	} m_optix;
#endif //NGP_OPTIX
};

// The 4-wide BVH: used by TriangleBvh::make() and by the CUDA kernels below.
using TriangleBvh4 = TriangleBvhWithBranchingFactor<4>;

// Factory for the default BVH implementation.
std::unique_ptr<TriangleBvh> TriangleBvh::make() {
	return std::make_unique<TriangleBvh4>();
}

// One thread per query point: writes the watertight signed distance of positions[i] into distances[i].
// If use_existing_distances_as_upper_bounds is set, the incoming distances[i] is the search radius;
// otherwise the radius is MAX_DIST. Threads with i >= n_elements do nothing.
__global__ void signed_distance_watertight_kernel(uint32_t n_elements,
	const vec3* __restrict__ positions,
	const TriangleBvhNode* __restrict__ bvhnodes,
	const Triangle* __restrict__ triangles,
	float* __restrict__ distances,
	bool use_existing_distances_as_upper_bounds
) {
	const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
	if (tid < n_elements) {
		const float radius = use_existing_distances_as_upper_bounds ? distances[tid] : MAX_DIST;
		distances[tid] = TriangleBvh4::signed_distance_watertight(positions[tid], bvhnodes, triangles, radius * radius);
	}
}

// One thread per query point: writes the ray-stab signed distance of positions[i] into distances[i].
// Each thread's generator is a default-constructed default_rng_t advanced by i * 2. If
// use_existing_distances_as_upper_bounds is set, the incoming distances[i] is the search radius;
// otherwise the radius is MAX_DIST. Threads with i >= n_elements do nothing.
__global__ void signed_distance_raystab_kernel(
	uint32_t n_elements,
	const vec3* __restrict__ positions,
	const TriangleBvhNode* __restrict__ bvhnodes,
	const Triangle* __restrict__ triangles,
	float* __restrict__ distances,
	bool use_existing_distances_as_upper_bounds
) {
	const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
	if (tid < n_elements) {
		const float radius = use_existing_distances_as_upper_bounds ? distances[tid] : MAX_DIST;

		default_rng_t rng;
		rng.advance(tid * 2);

		distances[tid] = TriangleBvh4::signed_distance_raystab(positions[tid], bvhnodes, triangles, radius * radius, rng);
	}
}

// One thread per query point: writes the unsigned distance from positions[i] to the mesh into
// distances[i]. If use_existing_distances_as_upper_bounds is set, the incoming distances[i] is the
// search radius; otherwise the radius is MAX_DIST. Threads with i >= n_elements do nothing.
__global__ void unsigned_distance_kernel(uint32_t n_elements,
	const vec3* __restrict__ positions,
	const TriangleBvhNode* __restrict__ bvhnodes,
	const Triangle* __restrict__ triangles,
	float* __restrict__ distances,
	bool use_existing_distances_as_upper_bounds
) {
	const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
	if (tid < n_elements) {
		const float radius = use_existing_distances_as_upper_bounds ? distances[tid] : MAX_DIST;
		distances[tid] = TriangleBvh4::unsigned_distance(positions[tid], bvhnodes, triangles, radius * radius);
	}
}

// One thread per ray (origin positions[i], direction directions[i]).
// Moves positions[i] to the closest hit point. On a miss it moves MAX_DIST along the ray, since
// that is the t ray_intersect reports. When a triangle is hit, directions[i] is replaced by that
// triangle's normal. Threads with i >= n_elements do nothing.
__global__ void raytrace_kernel(uint32_t n_elements, vec3* __restrict__ positions, vec3* __restrict__ directions, const TriangleBvhNode* __restrict__ nodes, const Triangle* __restrict__ triangles) {
	const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
	if (tid >= n_elements) {
		return;
	}

	const vec3 origin = positions[tid];
	const vec3 dir = directions[tid];

	const auto hit = TriangleBvh4::ray_intersect(origin, dir, nodes, triangles);
	positions[tid] = origin + hit.second * dir;

	const bool hit_triangle = hit.first >= 0;
	if (hit_triangle) {
		directions[tid] = triangles[hit.first].normal();
	}
}

}


